# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import bz2
import gzip
import io
import logging
import math
import os
import random
import tempfile
import unittest

import hamcrest as hc

import apache_beam as beam
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import range_trackers
# importing following private classes for testing
from apache_beam.io.concat_source import ConcatSource
from apache_beam.io.filebasedsource import _SingleFileSource as SingleFileSource
from apache_beam.io.filebasedsource import FileBasedSource
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher


class LineSource(FileBasedSource):
  """Test source that yields one record per line, newline bytes stripped.

  Only trailing b'\\n' bytes are removed from each record, so a b'\\r' left by
  CRLF line endings stays part of the record.
  """
  def read_records(self, file_name, range_tracker):
    handle = self.open_file(file_name)
    try:
      offset = range_tracker.start_position()
      if offset > 0:
        # A line belongs to this range only if it starts inside it. Backing up
        # one byte and discarding the rest of that line leaves the cursor on
        # the first line that starts at or after the range start.
        handle.seek(offset - 1)
        offset = offset - 1 + len(handle.readline())
      record = handle.readline()
      while range_tracker.try_claim(offset):
        # For an unsplittable source, a successful claim does not prove data
        # remains, so an empty read is treated as end-of-file.
        if not record:
          return
        yield record.rstrip(b'\n')
        offset += len(record)
        record = handle.readline()
    finally:
      handle.close()


class EOL:
  """Line-terminator policies accepted by ``write_data``."""
  # Every line ends with b'\n'.
  LF = 1
  # Every line ends with b'\r\n'.
  CRLF = 2
  # Even-numbered lines end with b'\n', odd-numbered lines with b'\r\n'.
  MIXED = 3
  # Like LF, except the final line has no terminator at all.
  LF_WITH_NOTHING_AT_LAST_LINE = 4


def write_data(
    num_lines,
    no_data=False,
    directory=None,
    prefix=tempfile.template,
    eol=EOL.LF):
  """Writes generated test lines to a new temporary file.

  The file is created with ``delete=False``, so it outlives this call.

  Args:
    num_lines (int): How many lines to write.
    no_data (bool): When :data:`True` every line is empty. Otherwise line ``i``
      is ``b'line'`` followed by ``i`` written in decimal.
    directory (str): Directory to create the file in, or ``None`` for the
      platform default.
    prefix (str): Prefix for the temporary file name.
    eol (int): Line-terminator policy; one of the attributes of
      :class:`~apache_beam.io.filebasedsource_test.EOL`.

  Returns:
    Tuple[str, List[bytes]]: The file name and the written lines, without
      their terminators.

  Raises:
    ValueError: If ``eol`` is not a known policy and at least one line is
      written.
  """
  lf, crlf = b'\n', b'\r\n'

  def terminator_for(index):
    # Chooses the line ending for line ``index`` under the ``eol`` policy.
    if eol == EOL.MIXED:
      return (lf, crlf)[index % 2]
    if eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
      return b'' if index == num_lines - 1 else lf
    if eol == EOL.LF:
      return lf
    if eol == EOL.CRLF:
      return crlf
    raise ValueError('Received unknown value %s for eol.' % eol)

  lines = []
  with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                   prefix=prefix) as out:
    for index in range(num_lines):
      line = b'' if no_data else b'line%d' % index
      lines.append(line)
      out.write(line + terminator_for(index))
  return out.name, lines


def _write_prepared_data(
    data, directory=None, prefix=tempfile.template, suffix=''):
  with tempfile.NamedTemporaryFile(delete=False,
                                   dir=directory,
                                   prefix=prefix,
                                   suffix=suffix) as f:
    f.write(data)
    return f.name


def write_prepared_pattern(data, suffixes=None):
  """Writes each item of ``data`` to its own file in a fresh temp directory.

  Args:
    data (List[bytes]): Contents of each file, one entry per file. Must be
      non-empty.
    suffixes (List[str]): Optional file-name suffix for each entry of
      ``data`` (for example ``'.gz'``). Defaults to no suffix.

  Returns:
    str: A glob pattern matching the files written by this call.
  """
  assert data, 'Data (%s) seems to be empty' % data
  if suffixes is None:
    suffixes = [''] * len(data)
  temp_dir = tempfile.mkdtemp()
  for i, d in enumerate(data):
    _write_prepared_data(d, temp_dir, prefix='mytemp', suffix=suffixes[i])
  # Build the pattern from the directory we created. Slicing the last file
  # name only worked because the loop variable leaked out of the loop.
  return os.path.join(temp_dir, 'mytemp*')


def write_pattern(lines_per_file, no_data=False):
  """Writes a pattern of temporary files.

  All files are created in a new temporary directory with the prefix
  ``mytemp``. Line numbering restarts at 0 in every file, so the returned data
  can contain duplicates across files.

  Args:
    lines_per_file (List[int]): The number of lines to write per file.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of b'line' and the line number
      within its file.

  Returns:
    Tuple[str, List[bytes]]: A tuple of the filename pattern and a list of the
      written data, in file-creation order. For an empty ``lines_per_file`` the
      pattern matches no files and the list is empty.
  """
  temp_dir = tempfile.mkdtemp()

  all_data = []
  for num_lines in lines_per_file:
    _, data = write_data(num_lines, no_data=no_data,
                         directory=temp_dir, prefix='mytemp')
    all_data.extend(data)

  # Build the pattern from the directory rather than from the last file name
  # produced by the loop.
  return os.path.join(temp_dir, 'mytemp*'), all_data


class TestConcatSource(unittest.TestCase):
  """Tests for ConcatSource built from small in-memory bounded sources."""

  class DummySource(iobase.BoundedSource):
    """Bounded source over an in-memory sequence of values."""
    def __init__(self, values):
      self._values = values

    def split(
        self, desired_bundle_size, start_position=None, stop_position=None):
      # Ignores the requested bundle size and always yields two halves of
      # equal weight.
      half = len(self._values) // 2
      for part in (self._values[:half], self._values[half:]):
        yield iobase.SourceBundle(
            0.5, TestConcatSource.DummySource(part), None, None)

    def get_range_tracker(self, start_position, stop_position):
      start = 0 if start_position is None else start_position
      stop = len(self._values) if stop_position is None else stop_position
      return range_trackers.OffsetRangeTracker(start, stop)

    def read(self, range_tracker):
      for position, item in enumerate(self._values):
        if range_tracker.try_claim(position):
          yield item
        else:
          return

    def estimate_size(self):
      # Each value is counted as a single byte.
      return len(self._values)

  def setUp(self):
    # Keep the size-estimation thread pool small so the tests can also run
    # in environments with limited resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def test_read(self):
    concat = ConcatSource([
        TestConcatSource.DummySource(range(offset, offset + 10))
        for offset in (0, 10, 20)
    ])
    tracker = concat.get_range_tracker(None, None)
    self.assertCountEqual(list(range(30)), list(concat.read(tracker)))

  def test_split(self):
    concat = ConcatSource([
        TestConcatSource.DummySource(list(range(offset, offset + 10)))
        for offset in (0, 10, 20)
    ])
    bundles = list(concat.split())
    self.assertEqual(6, len(bundles))

    # Read every bundle back and make sure nothing is lost or duplicated.
    collected = []
    for bundle in bundles:
      tracker = bundle.source.get_range_tracker(
          bundle.start_position, bundle.stop_position)
      collected.extend(bundle.source.read(tracker))
    self.assertCountEqual(list(range(30)), collected)

  def test_estimate_size(self):
    concat = ConcatSource([
        TestConcatSource.DummySource(range(offset, offset + 10))
        for offset in (0, 10, 20)
    ])
    self.assertEqual(30, concat.estimate_size())


class TestFileBasedSource(unittest.TestCase):
  """Tests for FileBasedSource, exercised through the LineSource test source.

  Checks use ``self.assert*`` rather than bare ``assert`` statements, because
  the latter are stripped under ``python -O`` and the tests would then pass
  vacuously.
  """

  # Line-index boundaries used to spread 200 generated lines over six files.
  _CHUNK_BOUNDARIES = [0, 34, 100, 140, 164, 188, 200]

  def setUp(self):
    # Reducing the size of thread pools. Without this test execution may fail in
    # environments with limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  @staticmethod
  def _temp_file_name(suffix=''):
    """Creates an empty temporary file and returns its path.

    The handle is closed before returning, so no file object is leaked and the
    file can be reopened by name on every platform.
    """
    with tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template, suffix=suffix) as f:
      return f.name

  @classmethod
  def _chunk(cls, lines):
    """Splits ``lines`` into consecutive chunks at ``_CHUNK_BOUNDARIES``."""
    bounds = cls._CHUNK_BOUNDARIES
    return [lines[start:stop] for start, stop in zip(bounds, bounds[1:])]

  @staticmethod
  def _gzip(data):
    """Returns ``data`` gzip-compressed in memory."""
    out = io.BytesIO()
    with gzip.GzipFile(fileobj=out, mode='wb') as f:
      f.write(data)
    return out.getvalue()

  @staticmethod
  def _read_all_splits(splits):
    """Reads every bundle in ``splits`` and returns all records in order."""
    records = []
    for split in splits:
      range_tracker = split.source.get_range_tracker(
          split.start_position, split.stop_position)
      records.extend(split.source.read(range_tracker))
    return records

  def test_string_or_value_provider_only(self):
    str_file_pattern = self._temp_file_name()
    self.assertEqual(
        str_file_pattern, FileBasedSource(str_file_pattern)._pattern.value)

    static_vp_file_pattern = StaticValueProvider(
        value_type=str, value=str_file_pattern)
    self.assertEqual(
        static_vp_file_pattern,
        FileBasedSource(static_vp_file_pattern)._pattern)

    runtime_vp_file_pattern = RuntimeValueProvider(
        option_name='arg', value_type=str, default_value=str_file_pattern)
    self.assertEqual(
        runtime_vp_file_pattern,
        FileBasedSource(runtime_vp_file_pattern)._pattern)
    # Reset runtime options to avoid side-effects in other tests.
    RuntimeValueProvider.set_runtime_options(None)

    invalid_file_pattern = 123
    with self.assertRaises(TypeError):
      FileBasedSource(invalid_file_pattern)

  def test_validation_file_exists(self):
    file_name, _ = write_data(10)
    LineSource(file_name)

  def test_validation_directory_non_empty(self):
    temp_dir = tempfile.mkdtemp()
    file_name, _ = write_data(10, directory=temp_dir)
    LineSource(file_name)

  def test_validation_failing(self):
    no_files_found_error = 'No files found based on the file pattern*'
    with self.assertRaisesRegex(IOError, no_files_found_error):
      LineSource('dummy_pattern')
    # Create the empty directory outside the assertion context so only the
    # source construction is expected to raise.
    temp_dir = tempfile.mkdtemp()
    with self.assertRaisesRegex(IOError, no_files_found_error):
      LineSource(os.path.join(temp_dir, '*'))

  def test_validation_file_missing_verification_disabled(self):
    LineSource('dummy_pattern', validate=False)

  def test_fully_read_single_file(self):
    file_name, expected_data = write_data(10)
    self.assertEqual(10, len(expected_data))
    fbs = LineSource(file_name)
    range_tracker = fbs.get_range_tracker(None, None)
    self.assertCountEqual(expected_data, list(fbs.read(range_tracker)))

  def test_single_file_display_data(self):
    file_name, _ = write_data(10)
    fbs = LineSource(file_name)
    dd = DisplayData.create_from(fbs)
    expected_items = [
        DisplayDataItemMatcher('file_pattern', file_name),
        DisplayDataItemMatcher('compression', 'auto')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_fully_read_file_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    self.assertEqual(40, len(expected_data))
    fbs = LineSource(pattern)
    range_tracker = fbs.get_range_tracker(None, None)
    self.assertCountEqual(expected_data, list(fbs.read(range_tracker)))

  def test_fully_read_file_pattern_with_empty_files(self):
    pattern, expected_data = write_pattern([5, 0, 12, 0, 8, 0])
    self.assertEqual(25, len(expected_data))
    fbs = LineSource(pattern)
    range_tracker = fbs.get_range_tracker(None, None)
    self.assertCountEqual(expected_data, list(fbs.read(range_tracker)))

  def test_estimate_size_of_file(self):
    file_name, expected_data = write_data(10)
    self.assertEqual(10, len(expected_data))
    fbs = LineSource(file_name)
    self.assertEqual(10 * 6, fbs.estimate_size())

  def test_estimate_size_of_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 10, 8, 8, 4])
    self.assertEqual(38, len(expected_data))
    fbs = LineSource(pattern)
    self.assertEqual(38 * 6, fbs.estimate_size())

    pattern, expected_data = write_pattern([5, 3, 9])
    self.assertEqual(17, len(expected_data))
    fbs = LineSource(pattern)
    self.assertEqual(17 * 6, fbs.estimate_size())

  def test_estimate_size_with_sampling_same_size(self):
    num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT
    pattern, _ = write_pattern([10] * num_files)
    # Each line will be of length 6 since write_pattern() uses
    # ('line' + line number + '\n') as data.
    self.assertEqual(
        6 * 10 * num_files, FileBasedSource(pattern).estimate_size())

  def test_estimate_size_with_sampling_different_sizes(self):
    num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT

    # write_pattern() writes ('line' + line number + '\n'). With ~500 lines
    # per file most lines (numbers >= 100) are 8 bytes. The shorter lines
    # numbered below 100 pull the average slightly under 8, which stays
    # within the tolerance below.
    base_size = 500
    variance = 5

    sizes = [
        int(random.uniform(base_size - variance, base_size + variance))
        for _ in range(num_files)
    ]
    pattern, _ = write_pattern(sizes)
    tolerance = 0.05
    self.assertAlmostEqual(
        base_size * 8 * num_files,
        FileBasedSource(pattern).estimate_size(),
        delta=base_size * 8 * num_files * tolerance)

  def test_splits_into_subranges(self):
    pattern, expected_data = write_pattern([5, 9, 6])
    self.assertEqual(20, len(expected_data))
    fbs = LineSource(pattern)
    splits = list(fbs.split(desired_bundle_size=15))
    # Every line here is 6 bytes, so each file yields ceil(size / 15) splits.
    expected_num_splits = (
        math.ceil(6 * 5 / 15) + math.ceil(6 * 9 / 15) + math.ceil(6 * 6 / 15))
    self.assertEqual(expected_num_splits, len(splits))

  def test_read_splits_single_file(self):
    file_name, expected_data = write_data(100)
    self.assertEqual(100, len(expected_data))
    fbs = LineSource(file_name)
    splits = list(fbs.split(desired_bundle_size=33))
    self.assertCountEqual(expected_data, self._read_all_splits(splits))

  def test_read_splits_file_pattern(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    self.assertEqual(200, len(expected_data))
    fbs = LineSource(pattern)
    splits = list(fbs.split(desired_bundle_size=50))
    self.assertCountEqual(expected_data, self._read_all_splits(splits))

  def _run_source_test(self, pattern, expected_data, splittable=True):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(pattern, splittable=splittable))
      assert_that(pcoll, equal_to(expected_data))

  def test_source_file(self):
    file_name, expected_data = write_data(100)
    self.assertEqual(100, len(expected_data))
    self._run_source_test(file_name, expected_data)

  def test_source_pattern(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    self.assertEqual(200, len(expected_data))
    self._run_source_test(pattern, expected_data)

  def test_unsplittable_does_not_split(self):
    pattern, expected_data = write_pattern([5, 9, 6])
    self.assertEqual(20, len(expected_data))
    fbs = LineSource(pattern, splittable=False)
    splits = list(fbs.split(desired_bundle_size=15))
    self.assertEqual(3, len(splits))

  def test_source_file_unsplittable(self):
    file_name, expected_data = write_data(100)
    self.assertEqual(100, len(expected_data))
    self._run_source_test(file_name, expected_data, False)

  def test_source_pattern_unsplittable(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    self.assertEqual(200, len(expected_data))
    self._run_source_test(pattern, expected_data, False)

  def test_read_file_bzip2(self):
    _, lines = write_data(10)
    filename = self._temp_file_name()
    with bz2.BZ2File(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              filename,
              splittable=False,
              compression_type=CompressionTypes.BZIP2))
      assert_that(pcoll, equal_to(lines))

  def test_read_file_gzip(self):
    _, lines = write_data(10)
    filename = self._temp_file_name()
    with gzip.GzipFile(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              filename,
              splittable=False,
              compression_type=CompressionTypes.GZIP))
      assert_that(pcoll, equal_to(lines))

  def test_read_pattern_bzip2(self):
    _, lines = write_data(200)
    compressed_chunks = [
        bz2.compress(b'\n'.join(c)) for c in self._chunk(lines)
    ]
    file_pattern = write_prepared_pattern(compressed_chunks)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              file_pattern,
              splittable=False,
              compression_type=CompressionTypes.BZIP2))
      assert_that(pcoll, equal_to(lines))

  def test_read_pattern_gzip(self):
    _, lines = write_data(200)
    compressed_chunks = [
        self._gzip(b'\n'.join(c)) for c in self._chunk(lines)
    ]
    file_pattern = write_prepared_pattern(compressed_chunks)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              file_pattern,
              splittable=False,
              compression_type=CompressionTypes.GZIP))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_single_file_bzip2(self):
    _, lines = write_data(10)
    filename = self._temp_file_name(suffix='.bz2')
    with bz2.BZ2File(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(filename, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_single_file_gzip(self):
    _, lines = write_data(10)
    filename = self._temp_file_name(suffix='.gz')
    with gzip.GzipFile(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(filename, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_pattern(self):
    _, lines = write_data(200)
    compressed_chunks = [
        self._gzip(b'\n'.join(c)) for c in self._chunk(lines)
    ]
    file_pattern = write_prepared_pattern(
        compressed_chunks, suffixes=['.gz'] * len(compressed_chunks))
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(file_pattern, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_pattern_compressed_and_uncompressed(self):
    _, lines = write_data(200)
    chunks_to_write = []
    for i, c in enumerate(self._chunk(lines)):
      joined = b'\n'.join(c)
      # Gzip every other file so the pattern mixes compressed and plain files.
      chunks_to_write.append(self._gzip(joined) if i % 2 == 0 else joined)
    file_pattern = write_prepared_pattern(
        chunks_to_write, suffixes=(['.gz', ''] * 3))
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(file_pattern, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_splits_get_coder_from_fbs(self):
    class DummyCoder(object):
      val = 12345

    class FileBasedSourceWithCoder(LineSource):
      def default_output_coder(self):
        return DummyCoder()

    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    self.assertEqual(200, len(expected_data))
    fbs = FileBasedSourceWithCoder(pattern)
    splits = list(fbs.split(desired_bundle_size=50))
    self.assertTrue(splits)
    for split in splits:
      self.assertEqual(DummyCoder.val, split.source.default_output_coder().val)


class TestSingleFileSource(unittest.TestCase):
  """Tests for _SingleFileSource, the per-file source used by FileBasedSource.

  Checks use ``self.assert*`` rather than bare ``assert`` statements, because
  the latter are stripped under ``python -O``.
  """

  def setUp(self):
    # Reducing the size of thread pools. Without this test execution may fail in
    # environments with limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def _ten_line_source(self):
    """Writes a 10-line file and wraps it in a full-range SingleFileSource.

    Each line is 'line<N>' plus a newline with N < 10, so it is 6 bytes long.
    The file is therefore 60 bytes and line i starts at offset 6 * i.

    Returns:
      Tuple of (SingleFileSource over [0, 60), list of expected records).
    """
    fbs = LineSource('dummy_pattern', validate=False)
    file_name, expected_data = write_data(10)
    self.assertEqual(10, len(expected_data))
    return SingleFileSource(fbs, file_name, 0, 10 * 6), expected_data

  @staticmethod
  def _read_all_splits(splits):
    """Reads every bundle in ``splits`` and returns all records in order."""
    records = []
    for split in splits:
      range_tracker = split.source.get_range_tracker(
          split.start_position, split.stop_position)
      records.extend(split.source.read(range_tracker))
    return records

  def test_source_creation_fails_for_non_number_offsets(self):
    start_not_a_number_error = 'start_offset must be a number*'
    stop_not_a_number_error = 'stop_offset must be a number*'
    file_name = 'dummy_pattern'
    fbs = LineSource(file_name, validate=False)

    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset='aaa', stop_offset='bbb')
    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset='aaa', stop_offset=100)
    with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset='bbb')
    with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=None)
    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=None, stop_offset=100)

  def test_source_creation_display_data(self):
    file_name = 'dummy_pattern'
    fbs = LineSource(file_name, validate=False)
    dd = DisplayData.create_from(fbs)
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_source_creation_fails_if_start_lg_stop(self):
    start_larger_than_stop_error = (
        'start_offset must be smaller than stop_offset*')
    fbs = LineSource('dummy_pattern', validate=False)
    SingleFileSource(
        fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
    with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=99)
    with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=100)

  def test_estimates_size(self):
    fbs = LineSource('dummy_pattern', validate=False)

    # Should simply return stop_offset - start_offset
    source = SingleFileSource(
        fbs, file_name='dummy_file', start_offset=0, stop_offset=100)
    self.assertEqual(100, source.estimate_size())

    source = SingleFileSource(
        fbs, file_name='dummy_file', start_offset=10, stop_offset=100)
    self.assertEqual(90, source.estimate_size())

  def test_read_range_at_beginning(self):
    source, expected_data = self._ten_line_source()
    # Lines starting at offsets 0, 6, 12 and 18 fall inside [0, 20).
    range_tracker = source.get_range_tracker(0, 20)
    self.assertCountEqual(expected_data[:4], list(source.read(range_tracker)))

  def test_read_range_at_end(self):
    source, expected_data = self._ten_line_source()
    # Lines starting at offsets 42, 48 and 54 fall inside [40, 60).
    range_tracker = source.get_range_tracker(40, 60)
    self.assertCountEqual(expected_data[-3:], list(source.read(range_tracker)))

  def test_read_range_at_middle(self):
    source, expected_data = self._ten_line_source()
    # Lines starting at offsets 24, 30 and 36 fall inside [20, 40).
    range_tracker = source.get_range_tracker(20, 40)
    self.assertCountEqual(expected_data[4:7], list(source.read(range_tracker)))

  def test_produces_splits_desiredsize_large_than_size(self):
    source, expected_data = self._ten_line_source()
    splits = list(source.split(desired_bundle_size=100))
    self.assertEqual(1, len(splits))
    self.assertEqual(60, splits[0].weight)
    self.assertEqual(0, splits[0].start_position)
    self.assertEqual(60, splits[0].stop_position)

    range_tracker = splits[0].source.get_range_tracker(None, None)
    read_data = list(splits[0].source.read(range_tracker))
    self.assertCountEqual(expected_data, read_data)

  def test_produces_splits_desiredsize_smaller_than_size(self):
    source, expected_data = self._ten_line_source()
    splits = list(source.split(desired_bundle_size=25))
    self.assertEqual(3, len(splits))
    self.assertCountEqual(expected_data, self._read_all_splits(splits))

  def test_produce_split_with_start_and_end_positions(self):
    source, expected_data = self._ten_line_source()
    splits = list(
        source.split(desired_bundle_size=15, start_offset=10, stop_offset=50))
    self.assertEqual(3, len(splits))
    # Lines 2..8 start at offsets 12..48, i.e. inside [10, 50).
    self.assertCountEqual(expected_data[2:9], self._read_all_splits(splits))


if __name__ == '__main__':
  # Show INFO-level logs from Beam while running the suite directly.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
